﻿using Cyss.Core;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;

namespace Cyss.Core.Repository.EF
{
    /// <summary>
    /// Represents database context extensions
    /// </summary>
    public static class DbContextExtensions
    {
        #region Fields

        // Static slot for a database name.
        // NOTE(review): a single process-wide value cannot distinguish contexts that point at different databases.
        private static string databaseName;
        // Cache of mapped table names, keyed by the entity CLR type's full name.
        private static readonly ConcurrentDictionary<string, string> tableNames = new ConcurrentDictionary<string, string>();
        // Cache of (property name, max length) pairs, keyed by the entity CLR type's full name.
        private static readonly ConcurrentDictionary<string, IEnumerable<(string, int?)>> columnsMaxLength = new ConcurrentDictionary<string, IEnumerable<(string, int?)>>();
        // Cache of (decimal property name, max value) pairs, keyed by the entity CLR type's full name.
        private static readonly ConcurrentDictionary<string, IEnumerable<(string, decimal?)>> decimalColumnsMaxValue = new ConcurrentDictionary<string, IEnumerable<(string, decimal?)>>();

        #endregion

        #region Utilities

        /// <summary>
        /// Materializes a detached copy of a tracked entity from one of its value snapshots
        /// </summary>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context; must be an EF <see cref="DbContext"/></param>
        /// <param name="entity">Entity instance to look up in the change tracker</param>
        /// <param name="getValuesFunction">Selects which snapshot of the tracked entry to copy (e.g. original or database values)</param>
        /// <returns>Copy of the entity, or null when the entity is not tracked or the snapshot is null</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/></exception>
        private static TEntity LoadEntityCopy<TEntity>(IDbContext context, TEntity entity, Func<EntityEntry<TEntity>, PropertyValues> getValuesFunction)
            where TEntity : BaseEntity
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            var dbContext = context as DbContext ?? throw new InvalidOperationException("Context does not support operation");

            // Locate the change-tracker entry that wraps this exact instance
            EntityEntry<TEntity> match = null;
            foreach (var trackedEntry in dbContext.ChangeTracker.Entries<TEntity>())
            {
                if (trackedEntry.Entity == entity)
                {
                    match = trackedEntry;
                    break;
                }
            }

            if (match == null)
                return null;

            var snapshot = getValuesFunction(match);
            return snapshot?.ToObject() as TEntity;
        }

        /// <summary>
        /// Splits a SQL script into individual commands on "GO" batch separators
        /// </summary>
        /// <remarks>
        /// Adapted from Microsoft.EntityFrameworkCore.Migrations.SqlServerMigrationsSqlGenerator.Generate.
        /// A "GO n" separator makes the preceding batch repeat n times.
        /// </remarks>
        /// <param name="sql">SQL script</param>
        /// <returns>List of commands</returns>
        private static IList<string> GetCommandsFromScript(string sql)
        {
            // Join backslash line continuations before splitting
            var joined = Regex.Replace(sql, @"\\\r?\n", string.Empty);
            var parts = Regex.Split(joined, @"^\s*(GO[ \t]+[0-9]+|GO)(?:\s+|$)", RegexOptions.IgnoreCase | RegexOptions.Multiline);
            var lastIndex = parts.Length - 1;
            var result = new List<string>();

            for (var index = 0; index <= lastIndex; index++)
            {
                var part = parts[index];

                // Blank fragments and the captured separators themselves are not commands
                if (string.IsNullOrWhiteSpace(part) || part.StartsWith("GO", StringComparison.OrdinalIgnoreCase))
                    continue;

                var repeat = 1;
                var next = index < lastIndex ? parts[index + 1] : null;
                if (next != null && next.StartsWith("GO", StringComparison.OrdinalIgnoreCase))
                {
                    var digits = Regex.Match(next, "([0-9]+)");
                    if (digits.Success)
                        repeat = int.Parse(digits.Value);
                }

                // Only the final batch gets a trailing line break, once per repetition
                var unit = index == lastIndex ? part + Environment.NewLine : part;
                result.Add(string.Concat(Enumerable.Repeat(unit, repeat)));
            }

            return result;
        }

        #endregion

        #region Methods

        /// <summary>
        /// Builds a detached copy of the entity from the values it had when it was first tracked
        /// </summary>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="entity">Tracked entity</param>
        /// <returns>Copy built from the original values, or null when the entity is not tracked</returns>
        public static TEntity LoadOriginalCopy<TEntity>(this IDbContext context, TEntity entity) where TEntity : BaseEntity
            => LoadEntityCopy<TEntity>(context, entity, tracked => tracked.OriginalValues);

        /// <summary>
        /// Builds a detached copy of the entity from the values currently stored in the database
        /// </summary>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="entity">Tracked entity</param>
        /// <returns>Copy built from the database values, or null when the entity is not tracked or no row is found</returns>
        public static TEntity LoadDatabaseCopy<TEntity>(this IDbContext context, TEntity entity) where TEntity : BaseEntity
            => LoadEntityCopy<TEntity>(context, entity, tracked => tracked.GetDatabaseValues());

        /// <summary>
        /// Drops a plugin table if it exists
        /// </summary>
        /// <param name="context">Database context</param>
        /// <param name="tableName">Table name (a single, unqualified identifier)</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null or <paramref name="tableName"/> is empty</exception>
        public static void DropPluginTable(this IDbContext context, string tableName)
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            if (string.IsNullOrEmpty(tableName))
                throw new ArgumentNullException(nameof(tableName));

            // Quote the name as an identifier, QUOTENAME-style with ']' doubled, so it cannot escape the brackets.
            // Single quotes are then doubled because the same quoted name is also embedded in a string literal for OBJECT_ID.
            var quotedName = "[" + tableName.Replace("]", "]]") + "]";
            var dbScript = $"IF OBJECT_ID(N'{quotedName.Replace("'", "''")}', 'U') IS NOT NULL DROP TABLE {quotedName}";

            context.ExecuteSqlCommand(dbScript);
            context.SaveChanges();
        }

        /// <summary>
        /// Gets the name of the table the entity type is mapped to
        /// </summary>
        /// <remarks>
        /// The result is cached process-wide, keyed by the entity CLR type's full name.
        /// </remarks>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <returns>Table name</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/>, or the entity type is not part of its model</exception>
        public static string GetTableName<TEntity>(this IDbContext context) where TEntity : BaseEntity
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            //try to get the EF database context
            if (!(context is DbContext dbContext))
                throw new InvalidOperationException("Context does not support operation");

            // A single GetOrAdd replaces the ContainsKey / TryAdd / TryGetValue sequence
            return tableNames.GetOrAdd(typeof(TEntity).FullName, entityTypeFullName =>
            {
                var entityType = dbContext.Model.FindRuntimeEntityType(typeof(TEntity))
                    ?? throw new InvalidOperationException($"Entity type '{entityTypeFullName}' is not part of the context model");

                //get the name of the table to which the entity type is mapped
                return entityType.GetTableName();
            });
        }

        /// <summary>
        /// Gets the maximum lengths of data that is allowed for the entity properties
        /// </summary>
        /// <remarks>
        /// The result is cached process-wide, keyed by the entity CLR type's full name.
        /// </remarks>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <returns>Collection of name - max length pairs (max length is null when none is configured)</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/>, or the entity type is not part of its model</exception>
        public static IEnumerable<(string Name, int? MaxLength)> GetColumnsMaxLength<TEntity>(this IDbContext context) where TEntity : BaseEntity
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            //try to get the EF database context
            if (!(context is DbContext dbContext))
                throw new InvalidOperationException("Context does not support operation");

            // Materialize the projection so the cache holds computed pairs rather than a deferred
            // query that re-walks the model on every enumeration.
            return columnsMaxLength.GetOrAdd(typeof(TEntity).FullName, entityTypeFullName =>
            {
                var entityType = dbContext.Model.FindEntityType(typeof(TEntity))
                    ?? throw new InvalidOperationException($"Entity type '{entityTypeFullName}' is not part of the context model");

                //get property name - max length pairs
                return entityType.GetProperties()
                    .Select(property => (property.Name, property.GetMaxLength()))
                    .ToList()
                    .AsReadOnly();
            });
        }

        /// <summary>
        /// Get maximum decimal values
        /// </summary>
        /// <remarks>
        /// For each decimal property the limit is 10^(precision - scale). The value is null when the column
        /// mapping has no precision/scale, or when the limit exceeds what <see cref="decimal"/> can represent
        /// (every decimal value then fits). Results are cached process-wide, keyed by the entity CLR type's full name.
        /// </remarks>
        /// <typeparam name="TEntity">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <returns>Collection of name - max decimal value pairs</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/>, or the entity type is not part of its model</exception>
        public static IEnumerable<(string Name, decimal? MaxValue)> GetDecimalColumnsMaxValue<TEntity>(this IDbContext context)
            where TEntity : BaseEntity
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            //try to get the EF database context
            if (!(context is DbContext dbContext))
                throw new InvalidOperationException("Context does not support operation");

            // Materialized so the store-type parsing runs once per entity type, not on every enumeration
            return decimalColumnsMaxValue.GetOrAdd(typeof(TEntity).FullName, entityTypeFullName =>
            {
                var entityType = dbContext.Model.FindEntityType(typeof(TEntity))
                    ?? throw new InvalidOperationException($"Entity type '{entityTypeFullName}' is not part of the context model");

                return entityType.GetProperties()
                    .Where(property => property.ClrType == typeof(decimal))
                    .Select(property =>
                    {
                        var mapping = new RelationalTypeMappingInfo(property);
                        if (!mapping.Precision.HasValue || !mapping.Scale.HasValue)
                            return (property.Name, (decimal?)null);

                        var integerDigits = mapping.Precision.Value - mapping.Scale.Value;

                        // decimal.MaxValue is about 7.9e28. With 29+ integer digits every decimal fits,
                        // and casting Math.Pow(10, n) to decimal would throw OverflowException.
                        if (integerDigits > 28)
                            return (property.Name, (decimal?)null);

                        return (property.Name, (decimal?)(decimal)Math.Pow(10, integerDigits));
                    })
                    .ToList()
                    .AsReadOnly();
            });
        }

        /// <summary>
        /// Gets the name of the database the context is connected to
        /// </summary>
        /// <param name="context">Database context</param>
        /// <returns>Database name, read from the context's connection</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/></exception>
        public static string DbName(this IDbContext context)
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            //try to get the EF database context
            if (!(context is DbContext dbContext))
                throw new InvalidOperationException("Context does not support operation");

            // Read from this context's own connection each time. A process-wide cache would hand the
            // first caller's database name to every other context, including ones bound to other databases.
            return dbContext.Database.GetDbConnection().Database;
        }

        /// <summary>
        /// Gets the connection string of the connection used by the context
        /// </summary>
        /// <param name="context">Database context</param>
        /// <returns>Connection string</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        /// <exception cref="InvalidOperationException">The context is not an EF <see cref="DbContext"/></exception>
        public static string ConnectionString(this IDbContext context)
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            var dbContext = context as DbContext;
            if (dbContext == null)
                throw new InvalidOperationException("Context does not support operation");

            var connection = dbContext.Database.GetDbConnection();
            return connection.ConnectionString;
        }

        /// <summary>
        /// Executes every command of a SQL script against the context database
        /// </summary>
        /// <remarks>
        /// The script is split on "GO" batch separators by GetCommandsFromScript, and each resulting batch
        /// is passed to <c>ExecuteSqlCommand</c> in order.
        /// </remarks>
        /// <param name="context">Database context</param>
        /// <param name="sql">SQL script</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="sql"/> is null</exception>
        public static void ExecuteSqlScript(this IDbContext context, string sql)
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            if (sql == null)
                throw new ArgumentNullException(nameof(sql));

            foreach (var command in GetCommandsFromScript(sql))
                context.ExecuteSqlCommand(command);
        }

        /// <summary>
        /// Executes a raw SQL query on the context's connection and maps the first result set to <typeparamref name="T"/>
        /// </summary>
        /// <remarks>
        /// The underlying helper opens the connection if it is not open and does not close it afterwards.
        /// </remarks>
        /// <typeparam name="T">Type each row is mapped to</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="Sql">SQL query text</param>
        /// <param name="cmdParms">Query parameters</param>
        /// <returns>Mapped rows</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="Sql"/> is null</exception>
        public static IEnumerable<T> QuerySql<T>(this IDbContext context, string Sql, params SqlParameter[] cmdParms) where T : class, new()
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            if (Sql == null)
                throw new ArgumentNullException(nameof(Sql));

            return QueryList<T>(context.GetDbConnection(), Sql, cmdParms);
        }

        /// <summary>
        /// Executes a query on the given connection and maps its first result set to a list of <typeparamref name="T"/>
        /// </summary>
        /// <remarks>
        /// The connection is opened if needed (via PrepareCommand) and intentionally left open: callers such as
        /// the bulk-copy helpers rely on it staying open afterwards. Errors from SQL Server propagate unchanged.
        /// </remarks>
        /// <typeparam name="T">Row type</typeparam>
        /// <param name="dbConnection">Connection to use; must be a <see cref="SqlConnection"/></param>
        /// <param name="SQLString">Query text</param>
        /// <param name="cmdParms">Query parameters</param>
        /// <returns>Mapped rows; an empty list when the statement returns no result set</returns>
        private static IList<T> QueryList<T>(IDbConnection dbConnection, string SQLString, params SqlParameter[] cmdParms) where T : class, new()
        {
            var connection = (SqlConnection)dbConnection;

            using (var ds = new DataSet())
            using (var cmd = new SqlCommand())
            {
                try
                {
                    PrepareCommand(cmd, connection, null, SQLString, cmdParms);
                    using (var da = new SqlDataAdapter(cmd))
                    {
                        da.Fill(ds, "ds");
                    }
                }
                finally
                {
                    // A SqlParameter can belong to only one command. Detach the caller's parameters even on
                    // failure so they can be reused, e.g. for a retry.
                    cmd.Parameters.Clear();
                }

                if (ds.Tables.Count == 0)
                    return new List<T>();

                return ds.Tables[0].ToList<T>();
            }
        }
        /// <summary>
        /// Configures a text command on the given connection, opening the connection if it is not open
        /// </summary>
        /// <param name="cmd">Command to configure</param>
        /// <param name="conn">Connection to run the command on</param>
        /// <param name="trans">Optional transaction; assigned only when not null</param>
        /// <param name="cmdText">SQL text</param>
        /// <param name="cmdParms">Parameters to attach; input parameters with a null value are sent as <see cref="DBNull"/></param>
        public static void PrepareCommand(SqlCommand cmd, SqlConnection conn, SqlTransaction trans, string cmdText, SqlParameter[] cmdParms)
        {
            if (conn.State != ConnectionState.Open)
                conn.Open();

            cmd.Connection = conn;
            cmd.CommandText = cmdText;
            if (trans != null)
                cmd.Transaction = trans;
            cmd.CommandType = CommandType.Text;

            if (cmdParms == null)
                return;

            foreach (var p in cmdParms)
            {
                // ADO.NET does not send a CLR null; DBNull.Value is the explicit SQL NULL
                var sendsValue = p.Direction == ParameterDirection.InputOutput || p.Direction == ParameterDirection.Input;
                if (sendsValue && p.Value == null)
                    p.Value = DBNull.Value;

                cmd.Parameters.Add(p);
            }
        }
        /// <summary>
        /// Executes the SQL script stored in a file against the context database
        /// </summary>
        /// <remarks>
        /// Does nothing when the file does not exist.
        /// </remarks>
        /// <param name="context">Database context</param>
        /// <param name="filePath">Path to the script file</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        public static void ExecuteSqlScriptFromFile(this IDbContext context, string filePath)
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            if (File.Exists(filePath))
            {
                var script = File.ReadAllText(filePath);
                context.ExecuteSqlScript(script);
            }
        }




        #endregion


        #region

        #region 删除

        /// <summary>
        /// Deletes the row with the given Id from the table named after <typeparamref name="T"/>
        /// </summary>
        /// <typeparam name="T">Entity type; its CLR type name is used as the table name</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="Id">Id of the row to delete</param>
        /// <returns>Value returned by <c>ExecuteSqlCommand</c> (presumably the affected row count)</returns>
        public static int Delete<T>(this IDbContext context, int Id) where T : BaseEntity
        {
            var table = typeof(T).Name;
            var deleteCommand = $" delete from [{table}] where Id='{Id}'";
            return context.ExecuteSqlCommand(deleteCommand);
        }

        /// <summary>
        /// Deletes the rows with the given Ids from the table named after <typeparamref name="T"/>
        /// </summary>
        /// <typeparam name="T">Entity type; its CLR type name is used as the table name</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="Ids">Ids of the rows to delete</param>
        /// <returns>Value returned by <c>ExecuteSqlCommand</c>, or 0 when <paramref name="Ids"/> is empty</returns>
        /// <exception cref="ArgumentNullException"><paramref name="Ids"/> is null</exception>
        public static int Delete<T>(this IDbContext context, IEnumerable<int> Ids) where T : BaseEntity
        {
            if (Ids == null)
                throw new ArgumentNullException(nameof(Ids));

            // Materialize once. An empty list would produce "in ()", which is a SQL syntax error,
            // so there is nothing to send to the server in that case.
            var idList = Ids as ICollection<int> ?? Ids.ToList();
            if (idList.Count == 0)
                return 0;

            string strSql = $" delete from [{typeof(T).Name}] where Id in ({idList.ToWhereString()}) ";
            return context.ExecuteSqlCommand(strSql);
        }

        /// <summary>
        /// Formats ids as a comma-separated list for an SQL <c>IN (...)</c> clause
        /// </summary>
        /// <param name="Ids">Ids to format</param>
        /// <returns>Comma-separated ids, or an empty string when there are none</returns>
        /// <exception cref="ArgumentNullException"><paramref name="Ids"/> is null</exception>
        private static string ToWhereString(this IEnumerable<int> Ids)
        {
            if (Ids == null)
                throw new ArgumentNullException(nameof(Ids));

            // string.Join is linear. The previous loop called StringBuilder.ToString() on every iteration,
            // which is quadratic in the number of ids.
            return string.Join(",", Ids);
        }

        #endregion

        #region  批量插入  

        /// <summary>
        /// Bulk-inserts entities into a table with SqlBulkCopy
        /// </summary>
        /// <typeparam name="T">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="list">Entities to insert</param>
        /// <param name="DistinctKeySelector">
        /// Optional selector of the fields that identify a duplicate. When given, rows whose key values already
        /// exist in the table are skipped.
        /// </param>
        /// <param name="dbTableName">Target table name; defaults to the entity type name</param>
        /// <returns>
        /// Rows reported by SqlBulkCopy when no selector is given; otherwise the rows affected by the de-duplicating INSERT
        /// </returns>
        public static int BulkInsert<T>(this IDbContext context, IEnumerable<T> list, Expression<Func<T, object>> DistinctKeySelector = null, string dbTableName = "") where T : BaseEntity, new()
        {
            var targetTable = string.IsNullOrWhiteSpace(dbTableName) ? typeof(T).Name : dbTableName;

            var table = list.ToList().ToDataTable();
            table.TableName = targetTable;

            IEnumerable<string> distinctFields = null;
            if (DistinctKeySelector != null)
                distinctFields = BaseSearchModelExpressionHelper.GetFields(DistinctKeySelector);

            return context.SqlBulkCopyInsert(table, distinctFields);
        }
        /// <summary>
        /// Bulk-inserts several entity lists, one destination table per list (named after the list's item type)
        /// </summary>
        /// <param name="context">Database context</param>
        /// <param name="distincts">
        /// Per-table duplicate-key column lists. With a single list only the first entry is used.
        /// </param>
        /// <param name="Lists">Entity lists to insert</param>
        /// <exception cref="ArgumentNullException"><paramref name="Lists"/> is null</exception>
        public static void BulkInsert(this IDbContext context, List<List<string>> distincts = null, params IEnumerable<BaseEntity>[] Lists)
        {
            if (Lists == null)
                throw new ArgumentNullException(nameof(Lists));

            // Nothing to insert: don't touch the connection or open a transaction for zero tables
            if (Lists.Length == 0)
                return;

            var tables = new List<DataTable>(Lists.Length);
            foreach (var list in Lists)
            {
                var dt = list.ToDataTable();
                dt.TableName = list.GetListItemType().Name;
                tables.Add(dt);
            }

            if (tables.Count == 1)
                context.SqlBulkCopyInsert(tables[0], distincts?.FirstOrDefault());
            else
                context.SqlBulkCopyInserts(distincts, tables.ToArray());
        }

        /// <summary>
        /// Bulk-updates existing rows of the target table from the given entities
        /// </summary>
        /// <remarks>
        /// The entities are bulk-copied into a #temp staging table on the context's connection. An UPDATE ... FROM
        /// then copies them into the target, matching rows on the identity column or primary key.
        /// </remarks>
        /// <typeparam name="T">Entity type</typeparam>
        /// <typeparam name="K">Result type of <paramref name="keySelector"/></typeparam>
        /// <param name="context">Database context</param>
        /// <param name="list">Entities carrying the new values</param>
        /// <param name="dbTableName">Target table name; defaults to the entity type name</param>
        /// <param name="keySelector">
        /// Optional expression selecting the fields to update. When null, all non-identity, non-primary-key columns are updated.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null</exception>
        public static void BulkUpdate<T, K>(this IDbContext context, IEnumerable<T> list, string dbTableName = "", Expression<Func<T, K>> keySelector = null) where T : BaseEntity, new()
        {
            if (context == null)
                throw new ArgumentNullException(nameof(context));

            var dt = list.ToList().ToDataTable();
            dbTableName = string.IsNullOrWhiteSpace(dbTableName) ? typeof(T).Name : dbTableName;
            var updateSql = context.GetBulkUpdateSql<T, K>(dbTableName, keySelector);

            // The staging #temp table is session-scoped, so the bulk copy and the UPDATE must run on the
            // context's own connection. Passing no connection cannot work.
            SqlBulkCopyUpdate(dt, dbTableName, updateSql, context.GetDbConnection());
        }


        /// <summary>
        /// Bulk "upsert": update rows whose primary key already exists in the target table and insert the rest
        /// </summary>
        /// <remarks>
        /// Not implemented. The method throws rather than returning silently, so callers are not misled into
        /// believing their data was written.
        /// </remarks>
        /// <typeparam name="T">Entity type</typeparam>
        /// <param name="context">Database context</param>
        /// <param name="list">Entities to upsert</param>
        /// <param name="dbTableName">Target table name; defaults to the entity type name</param>
        /// <exception cref="NotImplementedException">Always</exception>
        public static void BulkUpdateOrInsert<T>(this IDbContext context, IEnumerable<T> list, string dbTableName = "") where T : BaseEntity, new()
        {
            throw new NotImplementedException("BulkUpdateOrInsert is not implemented; use BulkUpdate and BulkInsert instead.");
        }
        /// <summary>
        /// Writes several DataTables to their destination tables with SqlBulkCopy inside one transaction.
        /// Intended for large volumes of data.
        /// </summary>
        /// <remarks>
        /// Each DataTable's TableName must be its destination table, and its column names must match the
        /// destination columns. Columns missing from the destination, or that are identity columns, are skipped.
        /// The context's connection is opened if needed and closed (not disposed) on exit.
        /// NOTE(review): <paramref name="distincts"/> is ignored here; no de-duplication is performed.
        /// </remarks>
        /// <param name="context">Database context whose connection is used</param>
        /// <param name="distincts">Currently unused (see remarks)</param>
        /// <param name="dts">Tables to import</param>
        /// <returns>Rows copied for each table, in input order</returns>
        private static IEnumerable<int> SqlBulkCopyInserts(this IDbContext context, List<List<string>> distincts = null, params DataTable[] dts)
        {
            var counts = new List<int>();
            if (dts == null || dts.Length == 0)
                return counts;

            // Load destination column metadata before the transaction starts, once per distinct table name
            var columnsByTable = new Dictionary<string, IEnumerable<DataField>>();
            foreach (var dt in dts)
            {
                if (!columnsByTable.ContainsKey(dt.TableName))
                    columnsByTable.Add(dt.TableName, context.GetDataColumns(dt.TableName));
            }

            // The connection is owned by the EF context. Open it if needed, but never dispose it here:
            // a disposed connection would break every later EF operation on the context.
            var sqlConnection = (SqlConnection)context.GetDbConnection();
            if (sqlConnection.State == ConnectionState.Closed)
                sqlConnection.Open();

            try
            {
                using (var transaction = sqlConnection.BeginTransaction())
                using (var sqlbulkcopy = new SqlBulkCopy(sqlConnection, SqlBulkCopyOptions.Default, transaction))
                {
                    try
                    {
                        foreach (var dt in dts)
                        {
                            var dataColumns = columnsByTable[dt.TableName];
                            sqlbulkcopy.ColumnMappings.Clear();
                            sqlbulkcopy.DestinationTableName = dt.TableName;

                            foreach (DataColumn column in dt.Columns)
                            {
                                // Only map columns that exist in the destination and are not identity columns
                                var dataColumn = dataColumns.FirstOrDefault(x => x.Name == column.ColumnName);
                                if (dataColumn == null || dataColumn.IsIdentity)
                                    continue;

                                sqlbulkcopy.ColumnMappings.Add(column.ColumnName, column.ColumnName);
                            }

                            sqlbulkcopy.WriteToServer(dt);
                            counts.Add(sqlbulkcopy.GetRowsCopied());
                        }

                        transaction.Commit();
                    }
                    catch
                    {
                        transaction.Rollback();
                        throw; // preserves the original stack trace, unlike "throw ex"
                    }
                }
            }
            finally
            {
                // Close (but keep usable) so the context can reopen the connection for later work
                if (sqlConnection.State != ConnectionState.Closed)
                    sqlConnection.Close();
            }

            return counts;
        }

        /// <summary>
        /// Writes a DataTable to its destination table with SqlBulkCopy, optionally skipping rows that already
        /// exist. Intended for large volumes of data.
        /// </summary>
        /// <remarks>
        /// The DataTable's TableName must be the destination table, and its column names must match the destination
        /// columns. Columns missing from the destination, or that are identity columns, are skipped.
        /// When <paramref name="distincts"/> is non-empty, rows are first staged in a #temp table. Then only rows
        /// whose distinct-key values are not already in the destination are inserted, and the staging table is dropped.
        /// </remarks>
        /// <param name="context">Database context whose connection is used</param>
        /// <param name="dt">Data to import</param>
        /// <param name="distincts">Columns identifying a duplicate; null or empty disables de-duplication</param>
        /// <returns>Rows copied by SqlBulkCopy, or rows inserted by the de-duplicating INSERT</returns>
        private static int SqlBulkCopyInsert(this IDbContext context, DataTable dt, IEnumerable<string> distincts)
        {
            var con = (SqlConnection)context.GetDbConnection();
            var dataColumns = context.GetDataColumns(dt.TableName);

            // The session-scoped #temp table and SqlBulkCopy both need an open connection
            if (con.State != ConnectionState.Open)
                con.Open();

            var distinctFields = distincts?.ToList() ?? new List<string>();
            var deduplicate = distinctFields.Count > 0;
            var tempTableName = deduplicate ? "#" + dt.TableName : dt.TableName;

            var insertColumns = new List<string>();
            using (var sqlbulkcopy = new SqlBulkCopy(con))
            {
                sqlbulkcopy.DestinationTableName = tempTableName;
                foreach (DataColumn column in dt.Columns)
                {
                    // Only map columns that exist in the destination and are not identity columns
                    var dataColumn = dataColumns.FirstOrDefault(x => x.Name == column.ColumnName);
                    if (dataColumn == null || dataColumn.IsIdentity)
                        continue;

                    insertColumns.Add(column.ColumnName);
                    sqlbulkcopy.ColumnMappings.Add(column.ColumnName, column.ColumnName);
                }

                if (!deduplicate)
                {
                    sqlbulkcopy.WriteToServer(dt);
                    return sqlbulkcopy.GetRowsCopied();
                }

                using (var sqlCommand = new SqlCommand { Connection = con })
                {
                    // Empty copy of the destination's schema to stage the incoming rows
                    sqlCommand.CommandText = $"select top 0 * into [{tempTableName}] from [{dt.TableName}]";
                    sqlCommand.ExecuteNonQuery();

                    try
                    {
                        sqlbulkcopy.WriteToServer(dt);

                        sqlCommand.CommandText = GetBulkInsertSql(dt.TableName, tempTableName, insertColumns, distinctFields);
                        return sqlCommand.ExecuteNonQuery();
                    }
                    finally
                    {
                        // The #temp table lives as long as the session. Drop it so a later call on the same
                        // connection can recreate it. This is best effort: a cleanup failure must not mask the
                        // original error.
                        if (con.State == ConnectionState.Open)
                        {
                            try
                            {
                                sqlCommand.CommandText = $"drop table [{tempTableName}]";
                                sqlCommand.ExecuteNonQuery();
                            }
                            catch (SqlException)
                            {
                                // deliberately ignored; see comment above
                            }
                        }
                    }
                }
            }
        }

        // Reflection handle for SqlBulkCopy's private "_rowsCopied" field, resolved once per process.
        // NOTE(review): this relies on a private implementation detail of Microsoft.Data.SqlClient.
        private static readonly FieldInfo rowsCopiedFieldInfo =
            typeof(SqlBulkCopy).GetField("_rowsCopied", BindingFlags.NonPublic | BindingFlags.GetField | BindingFlags.Instance);

        /// <summary>
        /// Reads the number of rows copied by the given <see cref="SqlBulkCopy"/> from its private "_rowsCopied" field
        /// </summary>
        /// <param name="bulkCopy">Bulk copy instance</param>
        /// <returns>Rows copied</returns>
        /// <exception cref="ArgumentNullException"><paramref name="bulkCopy"/> is null</exception>
        /// <exception cref="InvalidOperationException">The private field does not exist in the loaded SqlClient version</exception>
        public static int GetRowsCopied(this SqlBulkCopy bulkCopy)
        {
            if (bulkCopy == null)
                throw new ArgumentNullException(nameof(bulkCopy));

            if (rowsCopiedFieldInfo == null)
                throw new InvalidOperationException("SqlBulkCopy does not expose a '_rowsCopied' field in this SqlClient version");

            // Convert instead of unboxing, so the call keeps working if the field's integer width differs between versions
            return Convert.ToInt32(rowsCopiedFieldInfo.GetValue(bulkCopy));
        }

        /// <summary>
        /// Stages <paramref name="table"/> in a #temp copy of the target table with SqlBulkCopy, then runs
        /// <paramref name="strUpdateSql"/> to apply the staged values. Intended for large volumes of data.
        /// </summary>
        /// <remarks>
        /// The DataTable's column names must match the target table's columns one to one.
        /// Identity values are kept (<see cref="SqlBulkCopyOptions.KeepIdentity"/>) so staged rows can be matched to
        /// target rows by their identity key. The connection is opened if needed and closed when done.
        /// </remarks>
        /// <param name="table">Rows carrying the new values</param>
        /// <param name="tableName">Target table name; the staging table is "#" + this name</param>
        /// <param name="strUpdateSql">UPDATE statement that copies from the staging table into the target</param>
        /// <param name="dbConnection">Connection to use; must be a <see cref="SqlConnection"/></param>
        /// <returns>Rows affected by the UPDATE</returns>
        /// <exception cref="ArgumentNullException"><paramref name="dbConnection"/> is null</exception>
        private static int SqlBulkCopyUpdate(DataTable table, string tableName, string strUpdateSql, IDbConnection dbConnection)
        {
            if (dbConnection == null)
                throw new ArgumentNullException(nameof(dbConnection));

            var conn = (SqlConnection)dbConnection;

            // The connection may already be open, e.g. from the metadata query that built strUpdateSql
            if (conn.State != ConnectionState.Open)
                conn.Open();

            try
            {
                // SELECT INTO copies the IDENTITY property to the #temp table. Without KeepIdentity the staged
                // rows would get fresh identity values and the UPDATE would match the wrong target rows.
                using (var bulkCopy = new SqlBulkCopy(conn, SqlBulkCopyOptions.KeepIdentity, null))
                using (var sqlCommand = new SqlCommand { Connection = conn })
                {
                    bulkCopy.DestinationTableName = "#" + tableName;
                    foreach (DataColumn dc in table.Columns)
                    {
                        // Map each DataTable column to the staging column of the same name
                        bulkCopy.ColumnMappings.Add(dc.ColumnName, dc.ColumnName);
                    }

                    sqlCommand.CommandText = $"select top 0 * into [#{tableName}] from [{tableName}]";
                    sqlCommand.ExecuteNonQuery();

                    bulkCopy.WriteToServer(table);

                    sqlCommand.CommandText = strUpdateSql;
                    return sqlCommand.ExecuteNonQuery();
                }
            }
            finally
            {
                // Closing releases the session. SQL Server discards its #temp table when the session ends
                // or the pooled connection is reset.
                conn.Close();
            }
        }

        #endregion

        #endregion

        #region

        /// <summary>
        /// Reads column metadata (name, primary key, nullability, length, identity, type, description) for a table
        /// </summary>
        /// <param name="context">Database context whose connection runs the metadata query</param>
        /// <param name="tableName">Table name</param>
        /// <returns>One <c>DataField</c> per column</returns>
        private static IEnumerable<DataField> GetDataColumns(this IDbContext context, string tableName)
        {
            var columnSql = GetTableColumnSql(tableName);
            return context.QuerySql<DataField>(columnSql);
        }

        /// <summary>
        /// Builds the INSERT ... SELECT statement that copies staged rows into the target table, skipping rows whose
        /// distinct-key values already exist there
        /// </summary>
        /// <remarks>
        /// The target table is expected to have an <c>Id</c> column. A staged row is inserted only when the left join
        /// on <paramref name="Distincts"/> finds no target row (<c>b.[Id] is null</c>). Because the join uses '=',
        /// staged rows with NULL key values never match (under ANSI_NULLS) and are always inserted. Duplicates within
        /// the staged data itself are not removed. All identifiers are bracket-quoted.
        /// </remarks>
        /// <param name="tableName">Target table name</param>
        /// <param name="tempTableName">Staging table name (e.g. "#Target")</param>
        /// <param name="Columns">Columns to insert</param>
        /// <param name="Distincts">Columns whose combined values identify a duplicate</param>
        /// <returns>SQL statement text</returns>
        /// <exception cref="ArgumentNullException"><paramref name="Columns"/> or <paramref name="Distincts"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="Columns"/> or <paramref name="Distincts"/> is empty</exception>
        private static string GetBulkInsertSql(string tableName, string tempTableName, IEnumerable<string> Columns, IEnumerable<string> Distincts)
        {
            // Bracket-quote identifiers, doubling ']', so reserved words such as "User" and unusual names are safe
            string Quote(string name) => "[" + name.Replace("]", "]]") + "]";

            var columns = (Columns ?? throw new ArgumentNullException(nameof(Columns))).ToList();
            var distincts = (Distincts ?? throw new ArgumentNullException(nameof(Distincts))).ToList();
            if (columns.Count == 0)
                throw new ArgumentException("At least one column to insert is required", nameof(Columns));
            // An empty ON clause would be invalid SQL
            if (distincts.Count == 0)
                throw new ArgumentException("At least one distinct-key column is required", nameof(Distincts));

            var stringBuilder = new StringBuilder();
            stringBuilder.Append($" insert into {Quote(tableName)} ({string.Join(",", columns.Select(Quote))})");
            stringBuilder.Append(" select ");
            stringBuilder.Append(string.Join(",", columns.Select(c => "a." + Quote(c))));
            stringBuilder.Append($" from {Quote(tempTableName)} a left join {Quote(tableName)} b on ");
            stringBuilder.Append(string.Join(" and ", distincts.Select(f => $"a.{Quote(f)}=b.{Quote(f)}")));
            stringBuilder.Append(" where b.[Id] is null ");

            return stringBuilder.ToString();
        }
        /// <summary>
        /// Builds the UPDATE ... FROM statement that copies staged values from <c>#tableName</c> into <paramref name="tableName"/>
        /// </summary>
        /// <remarks>
        /// Rows are matched on the identity column when the table has one, otherwise on its primary-key columns.
        /// In the generated SQL, alias <c>a</c> is the target table and alias <c>b</c> is the staging table.
        /// </remarks>
        /// <typeparam name="T">Entity type</typeparam>
        /// <typeparam name="K">Result type of <paramref name="keySelector"/></typeparam>
        /// <param name="context">Database context used to read the table's column metadata</param>
        /// <param name="tableName">Target table name</param>
        /// <param name="keySelector">
        /// Optional expression producing the SET list. When null, every non-identity, non-primary-key column is updated.
        /// </param>
        /// <returns>UPDATE statement text</returns>
        /// <exception cref="InvalidOperationException">The table has no identity/primary key, or no column to update</exception>
        private static string GetBulkUpdateSql<T, K>(this IDbContext context, string tableName, Expression<Func<T, K>> keySelector)
        {
            // Bracket-quote identifiers, doubling ']', so reserved words and unusual names are safe
            string Quote(string name) => "[" + name.Replace("]", "]]") + "]";

            var dataFields = context.GetDataColumns(tableName).ToList();

            // Columns for the ON condition: the identity column if present, otherwise the primary key
            var identityField = dataFields.FirstOrDefault(x => x.IsIdentity);
            var joinFields = identityField != null
                ? new List<DataField> { identityField }
                : dataFields.Where(x => x.IsPrimaryKey).ToList();
            if (joinFields.Count == 0)
                throw new InvalidOperationException($"Table '{tableName}' has no identity or primary key column to match rows on");

            var stringBuilder = new StringBuilder(" update a set ");

            if (keySelector == null)
            {
                // Identity and primary-key values are never modified. New values come from the staging table (alias b).
                var setFields = dataFields.Where(x => !x.IsIdentity && !x.IsPrimaryKey).ToList();
                if (setFields.Count == 0)
                    throw new InvalidOperationException($"Table '{tableName}' has no updatable columns");

                stringBuilder.Append(string.Join(" , ", setFields.Select(x => $" {Quote(x.Name)}=b.{Quote(x.Name)} ")));
            }
            else
            {
                // NOTE(review): assumes the helper's SET expressions take the new values from the staging alias b — confirm
                stringBuilder.Append(ExpressionSetHelper.GetUpdateSetSqlByExpression(keySelector));
            }

            stringBuilder.Append($" from {Quote(tableName)} a join {Quote("#" + tableName)} b on ");
            stringBuilder.Append(string.Join(" and ", joinFields.Select(x => $" a.{Quote(x.Name)}=b.{Quote(x.Name)} ")));

            return stringBuilder.ToString();
        }

        /// <summary>
        /// Builds the catalog query that returns column metadata for a table
        /// </summary>
        /// <remarks>
        /// Returned columns: Name, IsPrimaryKey, IsNullable, Length, IsIdentity, Type, Details (extended property
        /// value). The table is matched by name only (SYSOBJECTS.name); rows whose type is 'sysname' are excluded.
        /// </remarks>
        /// <param name="tableName">Table name; single quotes are escaped before it is embedded in the query</param>
        /// <returns>SQL query text</returns>
        private static string GetTableColumnSql(string tableName)
        {
            // The name is embedded in a string literal. Double single quotes so a name containing "'"
            // (e.g. a caller-supplied dbTableName) cannot terminate the literal and inject SQL.
            var escapedTableName = tableName?.Replace("'", "''");

            string sql = $@"SELECT t1.name Name,case when t4.id is null then cast(0 as bit) else cast(1 as bit) end as IsPrimaryKey,cast(t1.isnullable as bit) as IsNullable,t1.length as Length,
                                      case when COLUMNPROPERTY( t1.id,t1.name,'IsIdentity') = 1 then cast(1 as bit) else cast(0 as bit) end as IsIdentity ,t5.name [Type]
                                     ,cast(isnull(t6.value,'') as varchar(2000)) Details
                                    FROM SYSCOLUMNS t1
                                    left join SYSOBJECTS t2 on t2.parent_obj = t1.id AND t2.xtype = 'PK'
                                    left join SYSINDEXES t3 on t3.id = t1.id and t2.name = t3.name
                                    left join SYSINDEXKEYS t4 on t1.colid = t4.colid and t4.id = t1.id and t4.indid = t3.indid
                                    left join systypes t5 on t1.xtype=t5.xtype
                                    left join sys.extended_properties t6 on t1.id=t6.major_id and t1.colid=t6.minor_id
                                    left join SYSOBJECTS tb on tb.id=t1.id
                                    where tb.name='{escapedTableName}' and t5.name<>'sysname'
                                    order by t1.colid asc";

            return sql;
        }

        #endregion
    }
}
